# ======================================================================
# Copyright CERFACS (March 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

import logging

# Module-level logger; used below to warn when the PBO backend cannot be imported.
logger = logging.getLogger(name="qaths.pbo_rules.IBMQ.gate_definitions")

try:
    from numpy import pi
    from qaths.pbo_rules.IBMQ.base_gates import U1, U2, U3, CX
    from qat.pbo import VAR, GraphCircuit

    # Pattern variables. Judging from their use below, each VAR() stands for a
    # gate parameter captured when a pattern matches and re-used in the
    # corresponding replacement.
    _theta = VAR()
    _phi = VAR()
    _lamb = VAR()

    # Functions applicable to pattern variables, used to derive replacement
    # parameters from matched ones (e.g. _div2_var(_lamb) -> lambda / 2).
    _div2_var = VAR.add_function(lambda x: x / 2)
    _minus_div2_var = VAR.add_function(lambda x: -x / 2)

    # Rewriting rules used by circ_to_ibmq_basis. Each entry is a
    # ``(pattern, replacement)`` pair; both members are lists of
    # ``(gate_name, qubit_indices, *parameters)`` tuples. Every replacement only
    # uses the IBMQ base gates U1, U2, U3 and CX.
    _gate_definitions = (
        ([("X", [0])], [("U3", [0], pi, 0, pi)]),
        ([("Y", [0])], [("U3", [0], pi, pi / 2, pi / 2)]),
        ([("Z", [0])], [("U1", [0], pi)]),
        ([("H", [0])], [("U2", [0], 0, pi)]),
        ([("S", [0])], [("U1", [0], pi / 2)]),
        ([("T", [0])], [("U1", [0], pi / 4)]),
        ([("RX", [0], _theta)], [("U3", [0], _theta, -pi / 2, pi / 2)]),
        ([("RY", [0], _theta)], [("U3", [0], _theta, 0, 0)]),
        # NOTE(review): assuming RZ(phi) = diag(exp(-i phi/2), exp(i phi/2)),
        # U1(phi) equals RZ(phi) only up to the global phase exp(i phi/2).
        ([("RZ", [0], _phi)], [("U1", [0], _phi)]),
        ([("PH", [0], _theta)], [("U1", [0], _theta)]),
        ([("CNOT", [0, 1])], [("CX", [0, 1])]),
        (
            # Controlled phase diag(1, 1, 1, exp(i lambda)).
            [("C-PH", [0, 1], _lamb)],
            [
                ("U1", [0], _div2_var(_lamb)),
                ("CX", [0, 1]),
                # Must be -lambda/2: when the control is |0> the two target
                # phases (-lambda/2 then +lambda/2) have to cancel out.
                ("U1", [1], _minus_div2_var(_lamb)),
                ("CX", [0, 1]),
                ("U1", [1], _div2_var(_lamb)),
            ],
        ),
        (
            # Standard 6-CX Toffoli decomposition, controls 0 and 1, target 2.
            [("CCNOT", [0, 1, 2])],
            [
                ("U2", [2], 0, pi),  # H 2
                ("CX", [1, 2]),
                ("U1", [2], -pi / 4),  # T.dag() 2
                ("CX", [0, 2]),
                ("U1", [2], pi / 4),  # T 2
                ("CX", [1, 2]),
                ("U1", [2], -pi / 4),  # T.dag() 2
                ("CX", [0, 2]),
                ("U1", [1], pi / 4),  # T 1
                ("U1", [2], pi / 4),  # T 2
                ("U2", [2], 0, pi),  # H 2
                # Phase correction acting only on the two controls:
                # CX(0,1) . (T 0, T.dag() 1) . CX(0,1). The target is done here.
                ("CX", [0, 1]),
                ("U1", [0], pi / 4),  # T 0
                ("U1", [1], -pi / 4),  # T.dag() 1
                ("CX", [0, 1]),
            ],
        ),
    )

    def _get_gate_name(gate, gate_dict) -> str:
        """Compute the textual name of ``gate``.

        When ``gate.syntax`` is set, its ``name`` is returned as-is. Otherwise, for
        a controlled gate, the name is built recursively as ``"C-"`` followed by
        the name of its sub-gate, looked up in ``gate_dict`` through
        ``gate.subgate`` (e.g. ``"C-PH"``, ``"C-C-PH"``).

        :param gate: the gate definition whose name is requested.
        :param gate_dict: mapping from gate identifiers to gate definitions, used
            to resolve ``gate.subgate``.
        :return: the name of the gate.
        :raise RuntimeError: if the gate has no syntax and is not controlled.
        """
        if gate.syntax is None and not gate.is_ctrl:
            raise RuntimeError(
                "Unsupported operation in get_gate_name: {}".format(gate)
            )
        if gate.syntax is not None:
            return gate.syntax.name
        controlled_name = _get_gate_name(gate_dict[gate.subgate], gate_dict)
        return "C-{}".format(controlled_name)

    def circ_to_ibmq_basis(circ, filter_before_replace: bool = True):
        """Rewrite ``circ`` so that it only uses the IBMQ gate set {U1, U2, U3, CX}.

        Each ``(pattern, replacement)`` rule of the module rule table is applied
        repeatedly, until ``replace_pattern`` reports that it no longer matched.

        :param circ: the circuit to translate to IBMQ basis.
        :param filter_before_replace: when True, skip every rule whose pattern
            starts with a gate name that does not occur in ``circ``'s gate
            dictionary, avoiding useless pattern searches.
        :return: the translated circuit obtained from ``GraphCircuit.to_circ``,
            meant to have the same effect as ``circ`` on a quantum state.
        """
        # Gate names present in the input circuit; only gathered when filtering
        # was requested (the set stays empty and unused otherwise).
        names_present = set()
        if filter_before_replace:
            gate_dict = circ.gateDic
            for gate_key in gate_dict:
                names_present.add(_get_gate_name(gate_dict[gate_key], gate_dict))

        graph = GraphCircuit()
        # Register the IBMQ base gates, presumably so that replacement sequences
        # can refer to them.
        for abstract_gate in (U1, U2, U3, CX):
            graph.add_abstract_gate(abstract_gate)
        graph.load_circuit(circ)

        for gate_pattern, replacement in _gate_definitions:
            # Rule keyed by the name of the first gate of its pattern.
            if filter_before_replace and gate_pattern[0][0] not in names_present:
                continue
            # Keep rewriting until this pattern is no longer found.
            while graph.replace_pattern(gate_pattern, replacement):
                pass
        return graph.to_circ()


except ImportError:
    logger.warning(
        "Pattern-based optimizer (PBO) not found. You are probably using "
        "MyQLM that does not include this feature. The circ_to_ibmq_basis function "
        "will have no effect."
    )

    def circ_to_ibmq_basis(circ, filter_before_replace: bool = True):
        return circ
